Source code for hysop.tools.spectral_utils
# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math, os
import numpy as np
import sympy as sm
from hysop import main_rank
from hysop.tools.io_utils import IOParams
from hysop.tools.htypes import check_instance, first_not_None, to_tuple
from hysop.tools.numerics import (
is_fp,
is_complex,
complex_to_float_dtype,
float_to_complex_dtype,
)
from hysop.tools.sympy_utils import Expr, Symbol, Dummy, subscript, tensor_symbol
from hysop.constants import BoundaryCondition, BoundaryExtension, TransformType
from hysop.fields.continuous_field import Field, ScalarField, TensorField
[docs]
class SpectralTransformUtils:
    """
    Class that contains useful methods for SpectralTransform setup.

    It classifies TransformType values (cosine/sine, R2R/R2C/C2R/C2C,
    forward/backward), converts boundary conditions to domain extensions and
    transforms, generates wave numbers and propagates data types through
    sequences of transforms.
    """

    cosine_transforms = (
        TransformType.DCT_I,
        TransformType.DCT_II,
        TransformType.DCT_III,
        TransformType.DCT_IV,
        TransformType.IDCT_I,
        TransformType.IDCT_II,
        TransformType.IDCT_III,
        TransformType.IDCT_IV,
    )
    sine_transforms = (
        TransformType.DST_I,
        TransformType.DST_II,
        TransformType.DST_III,
        TransformType.DST_IV,
        TransformType.IDST_I,
        TransformType.IDST_II,
        TransformType.IDST_III,
        TransformType.IDST_IV,
    )
    R2R_transforms = cosine_transforms + sine_transforms
    R2C_transforms = (TransformType.RFFT,)
    C2R_transforms = (TransformType.IRFFT,)
    C2C_transforms = (TransformType.FFT, TransformType.IFFT)
    forward_transforms = (
        TransformType.FFT,
        TransformType.RFFT,
        TransformType.DST_I,
        TransformType.DST_II,
        TransformType.DST_III,
        TransformType.DST_IV,
        TransformType.DCT_I,
        TransformType.DCT_II,
        TransformType.DCT_III,
        TransformType.DCT_IV,
    )
    backward_transforms = (
        TransformType.IFFT,
        TransformType.IRFFT,
        TransformType.IDST_I,
        TransformType.IDST_II,
        TransformType.IDST_III,
        TransformType.IDST_IV,
        TransformType.IDCT_I,
        TransformType.IDCT_II,
        TransformType.IDCT_III,
        TransformType.IDCT_IV,
    )

    @classmethod
    def is_cosine(cls, transform):
        """Return True if transform is a forward or inverse DCT (types I to IV)."""
        check_instance(transform, TransformType)
        return transform in cls.cosine_transforms

    @classmethod
    def is_sine(cls, transform):
        """Return True if transform is a forward or inverse DST (types I to IV)."""
        check_instance(transform, TransformType)
        return transform in cls.sine_transforms

    @classmethod
    def is_R2R(cls, transform):
        """Return True if transform is a real to real transform (any DCT or DST)."""
        check_instance(transform, TransformType)
        return transform in cls.R2R_transforms

    @classmethod
    def is_R2C(cls, transform):
        """Return True if transform is a real to complex transform (RFFT)."""
        check_instance(transform, TransformType)
        return transform in cls.R2C_transforms

    @classmethod
    def is_C2R(cls, transform):
        """Return True if transform is a complex to real transform (IRFFT)."""
        check_instance(transform, TransformType)
        return transform in cls.C2R_transforms

    @classmethod
    def is_C2C(cls, transform):
        """Return True if transform is a complex to complex transform (FFT, IFFT)."""
        check_instance(transform, TransformType)
        return transform in cls.C2C_transforms

    @classmethod
    def is_forward(cls, transform):
        """Return True if transform is one of the forward transforms."""
        check_instance(transform, TransformType)
        return transform in cls.forward_transforms

    @classmethod
    def is_backward(cls, transform):
        """Return True if transform is one of the backward (inverse) transforms."""
        check_instance(transform, TransformType)
        return transform in cls.backward_transforms

    @classmethod
    def is_none(cls, transform):
        """Return True if transform is TransformType.NONE (axis not transformed)."""
        check_instance(transform, TransformType)
        return transform is TransformType.NONE

    @classmethod
    def get_transform_offsets(cls, transform):
        """
        Return left and right transform offsets.

        The offsets are the number of points dropped on each side of the axis
        before applying the transform:
            DST_I -> (1, 1), DST_III -> (1, 0), DCT_III -> (0, 1), DCT_I -> (0, 0).
        Any non real to real transform (including NONE) returns (0, 0).

        Raises
        ------
        ValueError
            For any other real to real transform.
        """
        check_instance(transform, TransformType)
        if cls.is_R2R(transform):
            if transform is TransformType.DST_I:
                return (1, 1)
            elif transform is TransformType.DST_III:
                return (1, 0)
            elif transform is TransformType.DCT_III:
                return (0, 1)
            elif transform is TransformType.DCT_I:
                return (0, 0)
            else:
                msg = f"Unknown real to real forward transform {transform}."
                raise ValueError(msg)
        else:
            return (0, 0)

    @classmethod
    def get_transform_resolution(cls, resolution, *transforms):
        """
        Return the transformed shape and the (left, right) offsets of each axis.

        resolution is cast to a tuple of ints and must contain one entry per
        transform. Each output size is the input size minus the offsets given
        by get_transform_offsets.
        """
        resolution = to_tuple(resolution, cast=int)
        check_instance(transforms, tuple, values=TransformType, size=len(resolution))
        shape = []
        transform_offsets = []
        # NOTE(review): transforms are paired with resolution in reversed order,
        # which assumes callers pass them with opposite axis orderings -- confirm.
        for tr, si in zip(transforms[::-1], resolution):
            (lo, ro) = cls.get_transform_offsets(tr)
            shape.append(si - lo - ro)
            transform_offsets.append((lo, ro))
        return tuple(shape), tuple(transform_offsets)

    @classmethod
    def compute_wave_numbers(cls, transform, N, L, ftype):
        """
        Compute wave numbers of a given transform.

        Parameters
        ----------
        transform: TransformType
            One of FFT, RFFT, DCT_I, DST_I, DCT_III or DST_III.
        N: int
            Number of points along the axis.
        L: float
            Domain length along the axis.
        ftype: floating point dtype
            Real data type; FFT/RFFT results use the matching complex type.

        Returns
        -------
        np.ndarray
            2*pi*i*fftfreq(N, L/N) for FFT, 2*pi*i*rfftfreq(N, L/N) for RFFT,
            pi*k/L for k=0..N-1 (DCT_I, DST_I) and pi*(k+0.5)/L for k=0..N-2
            (DCT_III, DST_III).

        Raises
        ------
        ValueError
            For any other transform type.
        """
        check_instance(transform, TransformType)
        check_instance(N, int)
        check_instance(L, float)
        assert is_fp(ftype)
        otype = ftype
        if transform is TransformType.FFT:
            freqs = 2.0 * np.pi * 1j * np.fft.fftfreq(n=N, d=L / N)
            otype = float_to_complex_dtype(ftype)
        elif transform is TransformType.RFFT:
            freqs = 2.0 * np.pi * 1j * np.fft.rfftfreq(n=N, d=L / N)
            otype = float_to_complex_dtype(ftype)
        elif transform in (TransformType.DCT_I, TransformType.DST_I):
            freqs = np.pi * (np.arange(N, dtype=ftype) + 0.0) / L
        elif transform in (TransformType.DCT_III, TransformType.DST_III):
            # NOTE(review): N is decremented here but not for DST_I (offsets
            # (1, 1)); confirm whether callers pass the grid or transformed size.
            N -= 1
            freqs = np.pi * (np.arange(N, dtype=ftype) + 0.5) / L
        else:
            msg = f"Unknown transform type {transform}."
            raise ValueError(msg)
        freqs = freqs.astype(otype, copy=True)
        return freqs

    @classmethod
    def determine_output_dtype(cls, input_dtype, *transforms):
        """
        Compute output data type from input data type and list of forward transforms.

        Transforms are applied in the given order: R2R and C2C keep the data
        type, R2C turns a real type into a complex one and C2R does the opposite.

        Raises
        ------
        ValueError
            If a backward or unknown transform is given.
        """
        dtype = input_dtype
        for tr in transforms:
            if cls.is_backward(tr):
                msg = "{} is not a forward transform."
                msg = msg.format(tr)
                raise ValueError(msg)
            elif cls.is_none(tr):
                continue
            elif cls.is_R2R(tr):
                msg = f"Expected a floating point data type but got {dtype}."
                assert is_fp(dtype), msg
                # data type does not change
            elif cls.is_R2C(tr):
                msg = f"Expected a floating point data type but got {dtype}."
                assert is_fp(dtype), msg
                dtype = float_to_complex_dtype(dtype)
            elif cls.is_C2R(tr):
                msg = f"Expected a complex data type but got {dtype}."
                assert is_complex(dtype), msg
                dtype = complex_to_float_dtype(dtype)
            elif cls.is_C2C(tr):
                msg = f"Expected a complex data type but got {dtype}."
                assert is_complex(dtype), msg
                # data type does not change
            else:
                msg = f"Unknown transform type {tr}."
                raise ValueError(msg)
        return np.dtype(dtype)

    @classmethod
    def determine_input_dtype(cls, output_dtype, *transforms):
        """Compute input data type from output data type and list of backward transforms."""
        backward_transforms = cls.get_inverse_transforms(*transforms)
        return cls.determine_output_dtype(output_dtype, *backward_transforms)

    @classmethod
    def parse_expression(cls, expr, replace_pows=True):
        """
        Extract all wave_numbers from expression.

        If replace_pows is set, every integer power of a wave number gets its own
        symbol and is replaced in the expression (this allows to precompute
        wave number powers).

        Returns the parsed expression, the set of applied spectral transforms and
        the set of contained wave_numbers.
        """
        from hysop.symbolic.spectral import WaveNumber, AppliedSpectralTransform

        wave_numbers = set()
        transforms = set()

        def _extract(expr):
            if isinstance(expr, WaveNumber):
                wave_numbers.add(expr)
                return expr
            elif isinstance(expr, AppliedSpectralTransform):
                transforms.add(expr)
                return expr
            elif (
                replace_pows
                and isinstance(expr, sm.Pow)
                and isinstance(expr.args[0], WaveNumber)
                and isinstance(expr.args[1], (int, np.integer, sm.Integer))
            ):
                wn = expr.args[0].pow(int(expr.args[1]))
                wave_numbers.add(wn)
                return wn
            elif isinstance(expr, (sm.Symbol, sm.Number)):
                return expr
            elif isinstance(expr, sm.Expr):
                args = ()
                for a in expr.args:
                    args += (_extract(a),)
                try:
                    return expr.func(*args)
                except TypeError:
                    msg = f"\nFATAL ERROR: Failed to rebuild expr {expr}"
                    msg += f"\n type is {expr.func}"
                    msg += "\n"
                    print(msg)
                    raise
            else:
                return expr

        expr = _extract(expr)
        return (expr, transforms, wave_numbers)

    @classmethod
    def generate_wave_number(cls, transform, axis, exponent):
        """Create a new wavenumber. WaveNumbers are registered dummy symbols."""
        from hysop.symbolic.spectral import WaveNumber

        return WaveNumber(transform=transform, axis=axis, exponent=exponent)

    @classmethod
    def generate_wave_numbers(cls, *transforms):
        """
        Generate a tuple of wave_numbers (exponent 1) in transform order.
        Axis will match transform position.
        """
        wave_numbers = ()
        for i, tr in enumerate(transforms):
            wave_numbers += (cls.generate_wave_number(tr, i, 1),)
        return wave_numbers

    @classmethod
    def transforms_from_field(cls, field, transformed_axes):
        """
        Create a tuple of transforms by extracting field boundary conditions.
        Note that transforms are returned in natural ordering (ie. contiguous X-axis last).
        """
        check_instance(field, ScalarField)
        boundaries = tuple(
            (lbd, rbd)
            for (lbd, rbd) in zip(field.lboundaries_kind, field.rboundaries_kind)
        )
        # Boundaries are reversed before conversion and the result reversed back.
        transforms = cls.boundaries_to_transforms(boundaries[::-1], transformed_axes)[
            ::-1
        ]
        return transforms

    @classmethod
    def boundaries_to_transforms(cls, boundaries, transformed_axes):
        """
        Return a tuple of TransformType from a tuple of (left_boundary, right_boundary) pairs.
        """
        check_instance(boundaries, tuple, values=tuple)
        extensions = cls.boundaries_to_extensions(boundaries)
        transforms = cls.extensions_to_transforms(extensions, transformed_axes)
        return transforms

    @classmethod
    def boundaries_to_extensions(cls, boundaries):
        """
        Convert a BoundaryCondition pair tuple to a BoundaryExtension pair tuple.

        Raises
        ------
        ValueError
            If a boundary pair is not one of the supported combinations.
        """
        check_instance(boundaries, tuple, values=tuple)
        valid_boundary_pairs = (
            (BoundaryCondition.PERIODIC, BoundaryCondition.PERIODIC),
            (
                BoundaryCondition.HOMOGENEOUS_DIRICHLET,
                BoundaryCondition.HOMOGENEOUS_DIRICHLET,
            ),
            (
                BoundaryCondition.HOMOGENEOUS_NEUMANN,
                BoundaryCondition.HOMOGENEOUS_DIRICHLET,
            ),
            (
                BoundaryCondition.HOMOGENEOUS_DIRICHLET,
                BoundaryCondition.HOMOGENEOUS_NEUMANN,
            ),
            (
                BoundaryCondition.HOMOGENEOUS_NEUMANN,
                BoundaryCondition.HOMOGENEOUS_NEUMANN,
            ),
        )
        extensions = ()
        for boundary_pair in boundaries:
            if boundary_pair not in valid_boundary_pairs:
                msg = "Invalid boundary pair {}, valid ones are\n *{}"
                msg = msg.format(
                    boundary_pair,
                    "\n *".join(str(vbp) for vbp in valid_boundary_pairs),
                )
                raise ValueError(msg)
            (left_bd, right_bd) = boundary_pair
            left_ext = cls.boundary_to_extension(left_bd)
            right_ext = cls.boundary_to_extension(right_bd)
            extension_pair = (left_ext, right_ext)
            extensions += (extension_pair,)
        return extensions

    @classmethod
    def boundary_to_extension(cls, boundary):
        """
        Convert a BoundaryCondition to a BoundaryExtension.

        PERIODIC -> PERIODIC, HOMOGENEOUS_NEUMANN -> EVEN,
        HOMOGENEOUS_DIRICHLET -> ODD; anything else raises NotImplementedError.
        """
        check_instance(boundary, BoundaryCondition)
        if boundary is BoundaryCondition.PERIODIC:
            return BoundaryExtension.PERIODIC
        elif boundary is BoundaryCondition.HOMOGENEOUS_NEUMANN:
            return BoundaryExtension.EVEN
        elif boundary is BoundaryCondition.HOMOGENEOUS_DIRICHLET:
            return BoundaryExtension.ODD
        else:
            msg = f"Unknown boundary condition {boundary}."
            raise NotImplementedError(msg)

    @classmethod
    def extensions_to_transforms(cls, extensions, transformed_axes, is_complex=False):
        """
        Convert a BoundaryExtension pair tuple to a TransformType tuple.

        The i-th extension pair corresponds to axis dim-1-i; axes not listed in
        transformed_axes get TransformType.NONE. Once an RFFT has been emitted
        the data is complex, so later periodic axes get an FFT.

        Raises
        ------
        ValueError
            If a real to real transform would be applied to complex data.
        """
        dim = len(extensions)
        transforms = ()
        for i, extension_pair in enumerate(extensions):
            axis = dim - 1 - i
            if axis in transformed_axes:
                transform = cls.extension_to_transform(
                    *extension_pair, is_complex=is_complex
                )
                if is_complex and cls.is_R2R(transform):
                    raise ValueError(
                        "Data is complex but you try to apply a real to real transform."
                    )
                is_complex |= transform is TransformType.RFFT
            else:
                transform = TransformType.NONE
            transforms += (transform,)
        return transforms

    @classmethod
    def extension_to_transform(cls, left_ext, right_ext, is_complex=False):
        """
        Convert a BoundaryExtension pair to a TransformType.

        (PERIODIC, PERIODIC) -> FFT if is_complex else RFFT,
        (EVEN, EVEN) -> DCT_I, (EVEN, ODD) -> DCT_III,
        (ODD, EVEN) -> DST_III, (ODD, ODD) -> DST_I.

        Raises
        ------
        ValueError
            For any other extension pair.
        """
        check_instance(left_ext, BoundaryExtension)
        check_instance(right_ext, BoundaryExtension)
        valid_extension_pairs = (
            (BoundaryExtension.PERIODIC, BoundaryExtension.PERIODIC),
            (BoundaryExtension.ODD, BoundaryExtension.ODD),
            (BoundaryExtension.EVEN, BoundaryExtension.ODD),
            (BoundaryExtension.ODD, BoundaryExtension.EVEN),
            (BoundaryExtension.EVEN, BoundaryExtension.EVEN),
        )
        extension_pair = (left_ext, right_ext)
        msg = "Invalid domain extension pair {}, valid ones are\n *{}"
        msg = msg.format(
            extension_pair, "\n *".join(str(vep) for vep in valid_extension_pairs)
        )
        if extension_pair not in valid_extension_pairs:
            raise ValueError(msg)
        if left_ext is BoundaryExtension.PERIODIC:
            if right_ext is not BoundaryExtension.PERIODIC:
                raise ValueError(msg)
            if is_complex:
                return TransformType.FFT
            else:
                return TransformType.RFFT
        elif left_ext is BoundaryExtension.EVEN:
            if right_ext is BoundaryExtension.EVEN:
                return TransformType.DCT_I
            elif right_ext is BoundaryExtension.ODD:
                return TransformType.DCT_III
            else:
                raise ValueError(msg)
        elif left_ext is BoundaryExtension.ODD:
            if right_ext is BoundaryExtension.EVEN:
                return TransformType.DST_III
            elif right_ext is BoundaryExtension.ODD:
                return TransformType.DST_I
            else:
                raise ValueError(msg)
        else:
            raise ValueError(msg)

    @classmethod
    def get_inverse_transforms(cls, *transforms):
        """
        Get the inverse TransformType of a TransformType (for all input arguments).

        Raises
        ------
        NotImplementedError
            If a transform has no known inverse.
        """
        known_inverse_transforms = {
            TransformType.NONE: TransformType.NONE,
            TransformType.FFT: TransformType.IFFT,
            TransformType.RFFT: TransformType.IRFFT,
            TransformType.DCT_I: TransformType.IDCT_I,
            TransformType.DCT_II: TransformType.IDCT_II,
            TransformType.DCT_III: TransformType.IDCT_III,
            TransformType.DCT_IV: TransformType.IDCT_IV,
            TransformType.DST_I: TransformType.IDST_I,
            TransformType.DST_II: TransformType.IDST_II,
            TransformType.DST_III: TransformType.IDST_III,
            TransformType.DST_IV: TransformType.IDST_IV,
            TransformType.IFFT: TransformType.FFT,
            TransformType.IRFFT: TransformType.RFFT,
            TransformType.IDCT_I: TransformType.DCT_I,
            TransformType.IDCT_II: TransformType.DCT_II,
            TransformType.IDCT_III: TransformType.DCT_III,
            TransformType.IDCT_IV: TransformType.DCT_IV,
            TransformType.IDST_I: TransformType.DST_I,
            TransformType.IDST_II: TransformType.DST_II,
            TransformType.IDST_III: TransformType.DST_III,
            TransformType.IDST_IV: TransformType.DST_IV,
        }
        inverse_transforms = ()
        for tr in transforms:
            if tr not in known_inverse_transforms:
                msg = f"Unknown transform {tr}."
                raise NotImplementedError(msg)
            itr = known_inverse_transforms[tr]
            inverse_transforms += (itr,)
        return inverse_transforms

    @classmethod
    def get_conjugate_inverse_transforms(cls, *transforms):
        """
        Get the conjugate inverse TransformType (ie. inverse for odd derivatives).

        Sine transforms map to inverse cosine transforms and vice versa;
        FFT/RFFT map to their usual inverses.

        Raises
        ------
        NotImplementedError
            If a transform has no known conjugate inverse.
        """
        known_conjugate_inverse_transforms = {
            TransformType.NONE: TransformType.NONE,
            TransformType.FFT: TransformType.IFFT,
            TransformType.RFFT: TransformType.IRFFT,
            TransformType.DST_I: TransformType.IDCT_I,
            TransformType.DST_II: TransformType.IDCT_III,
            TransformType.DST_III: TransformType.IDCT_II,
            TransformType.DST_IV: TransformType.IDCT_IV,
            TransformType.DCT_I: TransformType.IDST_I,
            TransformType.DCT_II: TransformType.IDST_III,
            TransformType.DCT_III: TransformType.IDST_II,
            TransformType.DCT_IV: TransformType.IDST_IV,
            TransformType.IFFT: TransformType.FFT,
            TransformType.IRFFT: TransformType.RFFT,
            TransformType.IDST_I: TransformType.DCT_I,
            TransformType.IDST_III: TransformType.DCT_II,
            TransformType.IDST_II: TransformType.DCT_III,
            TransformType.IDST_IV: TransformType.DCT_IV,
            TransformType.IDCT_I: TransformType.DST_I,
            TransformType.IDCT_III: TransformType.DST_II,
            TransformType.IDCT_II: TransformType.DST_III,
            TransformType.IDCT_IV: TransformType.DST_IV,
        }
        conjugate_inverse_transforms = ()
        for tr in transforms:
            if tr not in known_conjugate_inverse_transforms:
                msg = f"Unknown transform {tr}."
                raise NotImplementedError(msg)
            citr = known_conjugate_inverse_transforms[tr]
            conjugate_inverse_transforms += (citr,)
        return conjugate_inverse_transforms
[docs]
def make_multivariate_trigonometric_polynomial(Xl, Xr, lboundaries, rboundaries, N):
    """
    Build a tensor product of trigonometric polynomials, one per axis, each
    satisfying that axis' boundary conditions.

    lboundaries: np.ndarray of BoundaryCondition (1D, at least one entry)
    rboundaries: np.ndarray of BoundaryCondition (same size as lboundaries)
    Xl, Xr, N:   scalar or array_like; a single value is broadcast to every axis.

    See make_trigonometric_polynomial for the meaning of each per-axis parameter.

    Returns (P, Y) where Y = (y0, y1, ..., yd) are sympy symbols and
        P(Y) = P0(y0) * P1(y1) * ... * Pd(yd)
    with d = lboundaries.size - 1 and Pi the trigonometric polynomial of order
    N[i] on [Xl[i], Xr[i]] satisfying (lboundaries[i], rboundaries[i]).
    """
    check_instance(lboundaries, np.ndarray, values=BoundaryCondition, ndim=1, minsize=1)
    check_instance(
        rboundaries, np.ndarray, values=BoundaryCondition, size=lboundaries.size
    )
    params = [to_tuple(Xl), to_tuple(Xr), to_tuple(N)]
    dim = max(max(len(p) for p in params), lboundaries.size, rboundaries.size)

    # Broadcast single-valued parameters to every axis.
    lefts, rights, orders = (p * dim if len(p) == 1 else p for p in params)

    check_instance(lefts, tuple, size=dim)
    check_instance(rights, tuple, size=dim)
    check_instance(orders, tuple, values=int, size=dim)
    assert lboundaries.size == rboundaries.size == dim
    assert all(left < right for (left, right) in zip(lefts, rights))
    assert all(order >= 1 for order in orders)

    _, Y = tensor_symbol("y", shape=(dim,))
    per_axis = zip(lefts, rights, lboundaries, rboundaries, orders, Y)
    P = 1
    for left, right, lbd, rbd, order, yi in per_axis:
        factor, x = make_trigonometric_polynomial(
            Xl=left, Xr=right, lboundary=lbd, rboundary=rbd, N=order
        )
        # Rename the 1D variable to this axis' symbol before multiplying.
        P *= factor.xreplace({x: yi})
    return (P, Y)
[docs]
def make_multivariate_polynomial(Xl, Xr, lboundaries, rboundaries, N, order):
    """
    Build a tensor product of polynomials, one per axis, each satisfying that
    axis' boundary conditions.

    lboundaries: np.ndarray of BoundaryCondition (1D, at least one entry)
    rboundaries: np.ndarray of BoundaryCondition (same size as lboundaries)
    Xl, Xr, N, order: scalar or array_like; a single value is broadcast to every axis.

    See make_polynomial for the meaning of each per-axis parameter.

    Returns (P, Y) where Y = (y0, y1, ..., yd) are sympy symbols and
        P(Y) = P0(y0) * P1(y1) * ... * Pd(yd)
    with d = lboundaries.size - 1 and Pi the polynomial built with N[i]
    coefficients on [Xl[i], Xr[i]] satisfying (lboundaries[i], rboundaries[i])
    up to order order[i].
    """
    check_instance(lboundaries, np.ndarray, values=BoundaryCondition, ndim=1, minsize=1)
    check_instance(
        rboundaries, np.ndarray, values=BoundaryCondition, size=lboundaries.size
    )
    params = [to_tuple(Xl), to_tuple(Xr), to_tuple(N), to_tuple(order)]
    dim = max(max(len(p) for p in params), lboundaries.size, rboundaries.size)

    # Broadcast single-valued parameters to every axis.
    lefts, rights, sizes, orders = (p * dim if len(p) == 1 else p for p in params)

    check_instance(lefts, tuple, size=dim)
    check_instance(rights, tuple, size=dim)
    check_instance(sizes, tuple, values=int, size=dim)
    check_instance(orders, tuple, values=int, size=dim)
    assert lboundaries.size == rboundaries.size == dim
    assert all(left < right for (left, right) in zip(lefts, rights))
    assert all(o >= 2 for o in orders)
    assert all(n > 2 * o for (o, n) in zip(orders, sizes))

    _, Y = tensor_symbol("y", shape=(dim,))
    per_axis = zip(lefts, rights, lboundaries, rboundaries, sizes, orders, Y)
    P = 1
    for left, right, lbd, rbd, size, max_order, yi in per_axis:
        factor, x = make_polynomial(
            Xl=left, Xr=right, lboundary=lbd, rboundary=rbd, N=size, order=max_order
        )
        # Rename the 1D variable to this axis' symbol before multiplying.
        P *= factor.xreplace({x: yi})
    return (P, Y)
[docs]
def make_polynomial(Xl, Xr, lboundary, rboundary, N, order):
    """
    Build a random polynomial of degree N-1 on domain [Xl, Xr] that verifies
    prescribed left and right boundary conditions up to a certain order.

    Conditions:
        Xl < Xr
        order >= 2
        N > 2*order (hence N >= 5)

    Valid boundary conditions (with K = 2*order, derivatives of order i < K) are:
        (PERIODIC, PERIODIC)      d^iP/dx^i(Xl) - d^iP/dx^i(Xr) = 0 for all i
        (HDIRICHLET, HDIRICHLET)  d^iP/dx^i(Xl) = d^iP/dx^i(Xr) = 0 for even i
        (HDIRICHLET, HNEUMANN)    even derivatives vanish at Xl, odd ones at Xr
        (HNEUMANN, HDIRICHLET)    odd derivatives vanish at Xl, even ones at Xr
        (HNEUMANN, HNEUMANN)      d^iP/dx^i(Xl) = d^iP/dx^i(Xr) = 0 for odd i

    Coefficients of degree > K are drawn uniformly in [-1, 1), as is one more
    coefficient (a[K] for (HDIRICHLET, HDIRICHLET), a[0] otherwise). The K
    boundary equations are solved with sympy, and any coefficient left free
    is drawn uniformly in [0, 1). The result is divided by its range
    (max - min) sampled on 1000 points of [Xl, Xr].

    Return (P, X) where P is a sympy expression (Horner form) that represents
    the polynomial and X is the corresponding sympy.Symbol.

    Raises
    ------
    NotImplementedError
        For an unsupported left or right boundary condition.
    """
    check_instance(lboundary, BoundaryCondition)
    check_instance(rboundary, BoundaryCondition)
    check_instance(N, int)
    check_instance(order, int)
    x = sm.Symbol("x")
    a, A = tensor_symbol("a", shape=(N,))

    def rand(*n):
        return 2.0 * (np.random.rand(*n) - 0.5)

    K = 2 * order
    assert Xl < Xr
    assert order >= 2
    assert N > K
    if N > K:
        a[K + 1 :] = rand(N - K - 1)
    both_dirichlet = (lboundary, rboundary) == (
        BoundaryCondition.HOMOGENEOUS_DIRICHLET,
        BoundaryCondition.HOMOGENEOUS_DIRICHLET,
    )
    if both_dirichlet:
        # a[0] enters the P(Xl) = P(Xr) = 0 equations: keep it as an unknown.
        a[K] = rand()
    else:
        a[0] = rand()
    P = sum(ai * (x**i) for (i, ai) in enumerate(a))
    # Pd[i] is the i-th derivative of P, for i = 0..K.
    Pd = [P]
    for i in range(K):
        Pd.append(Pd[-1].diff(x))
    eqs = []
    for i in range(order):
        if lboundary is BoundaryCondition.PERIODIC:
            leq = Pd[2 * i].xreplace({x: Xl}) - Pd[2 * i].xreplace({x: Xr})
        elif lboundary is BoundaryCondition.HOMOGENEOUS_NEUMANN:
            leq = Pd[2 * i + 1].xreplace({x: Xl})
        elif lboundary is BoundaryCondition.HOMOGENEOUS_DIRICHLET:
            leq = Pd[2 * i].xreplace({x: Xl})
        else:
            msg = f"Unknown left boundary condition {lboundary}."
            raise NotImplementedError(msg)
        if rboundary is BoundaryCondition.PERIODIC:
            req = Pd[2 * i + 1].xreplace({x: Xl}) - Pd[2 * i + 1].xreplace({x: Xr})
        elif rboundary is BoundaryCondition.HOMOGENEOUS_NEUMANN:
            req = Pd[2 * i + 1].xreplace({x: Xr})
        elif rboundary is BoundaryCondition.HOMOGENEOUS_DIRICHLET:
            req = Pd[2 * i].xreplace({x: Xr})
        else:
            msg = f"Unknown right boundary condition {rboundary}."
            raise NotImplementedError(msg)
        # Equations without unknowns are trivially satisfied (or not solvable).
        if leq.free_symbols:
            eqs.append(leq)
        if req.free_symbols:
            eqs.append(req)
    sol = sm.solve(eqs)
    P = P.xreplace(sol)
    # Draw values for the coefficients the boundary equations left free.
    sol.update({ai: np.random.rand() for ai in P.free_symbols.intersection(A)})
    P = P.xreplace(sol)
    P0 = sm.lambdify(x, sm.horner(P))
    X = np.linspace(Xl, Xr, 1000)
    m, M = np.min(P0(X)), np.max(P0(X))
    P /= M - m
    return sm.horner(P), x
[docs]
def make_trigonometric_polynomial(Xl, Xr, lboundary, rboundary, N):
    """
    Build a random real trigonometric polynomial of order N on the domain
    [Xl, Xr] that verifies the prescribed left and right boundary conditions.

    The polynomial is the sum of N random modes n = 1..N expressed in the
    reduced variable y = 2*pi*(x - Xl)/(Xr - Xl). It is then rescaled so that
    its range (max - min sampled on 1000 points of [Xl, Xr]) equals 2.

    Conditions:
        Xl < Xr
        N >= 1

    Valid boundary conditions and the corresponding modes are:
        (PERIODIC, PERIODIC)      a*cos(n*y + pi*p) + b*sin(n*y + pi*q)
        (HDIRICHLET, HDIRICHLET)  a*sin(n*y)
        (HDIRICHLET, HNEUMANN)    a*sin((4n-1)/4 * y)
        (HNEUMANN, HDIRICHLET)    a*cos((4n-1)/4 * y)
        (HNEUMANN, HNEUMANN)      a*cos(n*y)
    where a, b, p and q are drawn uniformly in [-1, 1).

    Return (P, X) where P is a sympy expression that represents the polynomial
    and X is the corresponding sympy.Symbol.

    Raises
    ------
    NotImplementedError
        For any other boundary condition pair.
    """
    assert N >= 1
    assert Xl < Xr

    def rnd(*shape):
        return 2.0 * (np.random.rand(*shape) - 0.5)

    x = sm.Symbol("x")
    y = (x - Xl) / (Xr - Xl) * (2 * sm.pi)

    def periodic_mode(n):
        # Draw order matters for reproducibility: cos amplitude, cos phase,
        # sin amplitude, sin phase.
        cos_amplitude = rnd()
        cos_phase = rnd()
        sin_amplitude = rnd()
        sin_phase = rnd()
        return cos_amplitude * sm.cos(n * y + sm.pi * cos_phase) + sin_amplitude * sm.sin(
            n * y + sm.pi * sin_phase
        )

    dirichlet = BoundaryCondition.HOMOGENEOUS_DIRICHLET
    neumann = BoundaryCondition.HOMOGENEOUS_NEUMANN
    mode_builders = (
        ((BoundaryCondition.PERIODIC, BoundaryCondition.PERIODIC), periodic_mode),
        ((dirichlet, dirichlet), lambda n: rnd() * sm.sin(n * y)),
        ((dirichlet, neumann), lambda n: rnd() * sm.sin((4 * n - 1) / 4.0 * y)),
        ((neumann, dirichlet), lambda n: rnd() * sm.cos((4 * n - 1) / 4.0 * y)),
        ((neumann, neumann), lambda n: rnd() * sm.cos(n * y)),
    )

    boundaries = (lboundary, rboundary)
    for pair, mode in mode_builders:
        if boundaries == pair:
            break
    else:
        msg = f"Unknown right boundary condition pair {boundaries}."
        raise NotImplementedError(msg)

    P = sum(mode(n) for n in range(1, N + 1))
    samples = sm.lambdify(x, P)(np.linspace(Xl, Xr, 1000))
    low, high = np.min(samples), np.max(samples)
    P *= 2.0 / (high - low)
    return (P, x)
[docs]
class EnergyDumper:
    """
    Append the power spectrum held by a BufferParameter to a text file.

    Only the process whose rank equals ``io_params.io_leader`` writes; on other
    ranks ``file``, ``formatter`` and ``ulp`` are None and ``update`` does nothing.
    Each dumped line holds the iteration, the time and log10 of the spectrum
    without its first coefficient, clamped from below by ``ulp``.
    """

    def __init__(self, energy_parameter, io_params, fname, **kwds):
        """
        Parameters
        ----------
        energy_parameter: BufferParameter
            Spectrum to dump (size >= 2, dtype float32 or float64).
        io_params: IOParams
            Output parameters; ``{fname}`` in io_params.filename is replaced by
            fname and ``{ite}`` is removed.
        fname: str
            Name substituted in the output filename.
        """
        from hysop.parameters.buffer_parameter import BufferParameter

        super().__init__(**kwds)
        check_instance(io_params, IOParams)
        check_instance(io_params.filepath, str)
        check_instance(io_params.io_leader, int)
        filename = io_params.filename.replace("{fname}", fname).replace("{ite}", "")
        assert "{fname}" not in filename
        assert "{ite}" not in filename
        assert io_params.frequency >= 0
        check_instance(energy_parameter, BufferParameter)
        assert energy_parameter.size >= 2
        assert energy_parameter.dtype in (np.float32, np.float64)

        should_write = io_params.io_leader == main_rank
        handle, formatter, ulp = None, None, None
        if should_write:
            ulp = np.finfo(energy_parameter.dtype).eps ** 4
            formatter = {"float_kind": lambda x: f"{x:8.8f}"}
            # Start from an empty file at every run.
            if os.path.isfile(filename):
                os.remove(filename)
            handle = open(filename, "a")
            header_chunks = (
                f"# Evolution of the power spectrum of {energy_parameter.pretty_name}",
                f"\n# with mean removed (first coefficient) and values clamped to ulp = epsilon^4 = {ulp}",
                "\n# ITERATION TIME log10(max(POWER_SPECTRUM[1:], ulp)))",
            )
            for chunk in header_chunks:
                handle.write(chunk)
            handle.flush()

        self.should_write = should_write
        self.energy_parameter = energy_parameter
        self.io_params = io_params
        self.file = handle
        self.formatter = formatter
        self.ulp = ulp

    def update(self, simulation, wait_for):
        """
        Append one line for the current iteration when this rank writes and
        io_params requests a dump; waits on wait_for first when it is not None.
        """
        if not (self.should_write and self.io_params.should_dump(simulation=simulation)):
            return
        if wait_for is not None:
            wait_for.wait()
        assert self.file is not None
        spectrum = self.energy_parameter.value[1:].astype(dtype=np.float64)
        log_spectrum = np.log10(np.maximum(spectrum, self.ulp))
        text = np.array2string(
            log_spectrum, formatter=self.formatter, max_line_width=np.inf
        )
        # text[1:-1] strips the surrounding brackets of the numpy representation.
        self.file.write(
            f"\n{simulation.current_iteration} {simulation.t()} {text[1:-1]}"
        )
        self.file.flush()

    @classmethod
    def do_compute_energy(cls, *args):
        """
        Collect the non-negative frequencies of the IOParams among args.

        Returns (True, frequencies) when at least one was found, else (False, None).
        Arguments that are not IOParams are ignored.
        """
        candidates = (iop.frequency for iop in args if isinstance(iop, IOParams))
        frequencies = {freq for freq in candidates if freq >= 0}
        if not frequencies:
            return False, None
        return True, frequencies

    @classmethod
    def build_energy_parameter(cls, do_compute, field, output_params, prefix):
        """
        Create the BufferParameter ``E{prefix}_{field.name}`` and add it to
        output_params when do_compute is set; return None otherwise.
        """
        from hysop.parameters.buffer_parameter import BufferParameter

        if not do_compute:
            return None
        param = BufferParameter(
            name=f"E{prefix}_{field.name}",
            pretty_name=f"E{prefix}_{field.pretty_name}",
            shape=None,
            dtype=None,
            initial_value=None,
        )
        assert param not in output_params, param.name
        output_params.update({param})
        return param
[docs]
class EnergyPlotter:
    """
    Plot power spectra stored in BufferParameters on a log-log figure.

    Only the process whose rank equals ``io_params.visu_leader`` builds,
    refreshes and saves the figure; ``update`` does nothing on other ranks.
    ``io_params.filename`` must contain the ``{ite}`` placeholder, which is
    replaced by the zero-padded iteration number when saving.
    """

    def __init__(
        self,
        energy_parameters,
        io_params,
        fname,
        fig_title=None,
        axes_shape=(1,),
        figsize=(15, 9),
        basex=10,
        basey=10,
        **kwds,
    ):
        """
        Parameters
        ----------
        energy_parameters: dict
            Maps a curve label to a BufferParameter; coefficient 0 is not plotted.
        io_params: IOParams
            Output parameters; filename must contain ``{ite}``.
        fname: str
            Substituted for ``{fname}`` in io_params.filename.
        fig_title: str, optional
            Title prefix; defaults to the pretty names of the parameters.
        axes_shape: tuple
            Passed to matplotlib.pyplot.subplots; only the first axes is used.
        figsize: tuple
            Figure size passed to matplotlib.pyplot.subplots.
        basex, basey: int
            Bases of the logarithmic x and y axes.
        """
        import matplotlib
        import matplotlib.pyplot as plt
        from hysop.parameters.buffer_parameter import BufferParameter

        super().__init__(**kwds)
        check_instance(io_params, IOParams)
        check_instance(axes_shape, tuple, minsize=1, allow_none=True)
        check_instance(energy_parameters, dict, values=BufferParameter)
        should_draw = io_params.visu_leader == main_rank
        filename = io_params.filename
        filename = filename.replace("{fname}", fname)
        assert "{ite}" in filename, filename
        assert io_params.frequency >= 0
        if should_draw:
            fig, axes = plt.subplots(*axes_shape, figsize=figsize)
            fig.canvas.mpl_connect("key_press_event", self._on_key_press)
            fig.canvas.mpl_connect("close_event", self._on_close)
            axes = np.asarray(axes).reshape(axes_shape)
            ax = axes[0]
            if len(energy_parameters) == 1:
                default_fig_title = "Energy parameter {}".format(
                    next(iter(energy_parameters.values())).pretty_name
                )
            else:
                default_fig_title = "Energy parameters {}".format(
                    " | ".join(p.pretty_name for p in energy_parameters.values())
                )
            fig_title = first_not_None(fig_title, default_fig_title)
            self.fig_title = fig_title
            xmax = 1
            lines = ()
            for label, p in energy_parameters.items():
                assert p.size > 1
                Ix = np.arange(1, p.size)
                xmax = max(xmax, p.size - 1)
                line = ax.plot(
                    Ix, np.zeros(p.size - 1), "--o", label=label, markersize=3
                )[0]
                lines += (line,)
            xmin = 1
            pmax = math.ceil(math.log(xmax, basex))
            xmax = basex**pmax if basex == 2 else 1.1 * xmax
            ax.set_xlim(xmin, xmax)
            ax.set_title("t=None")
            ax.set_xlabel("Wavenumber")
            ax.set_ylabel("Energy")
            # `base` replaces the `basex`/`basey` keywords removed in matplotlib 3.5.
            ax.set_xscale("log", base=basex)
            ax.set_yscale("log", base=basey)
            ax.legend(loc="upper right")
            ax.grid(True, which="major", ls="--", c="k")
            ax.grid(True, which="minor", ls=":")
            if basex == 2:
                ax.xaxis.set_major_formatter(
                    matplotlib.ticker.FuncFormatter(lambda x, pos: str(int(round(x))))
                )
            else:
                ax.xaxis.set_major_formatter(
                    matplotlib.ticker.FuncFormatter(
                        lambda x, pos: rf"$\mathbf{{{basex}^{{{int(round(math.log(x, basex)))}}}}}$"
                    )
                )
            ax.yaxis.set_minor_formatter(
                matplotlib.ticker.FuncFormatter(
                    lambda x, pos, ax=ax: rf"$^{{{int(round(math.log(x, basey)))}}}$"
                )
            )
            ax.yaxis.set_major_formatter(
                matplotlib.ticker.FuncFormatter(
                    lambda x, pos: rf"$\mathbf{{{basey}^{{{int(round(math.log(x, basey)))}}}}}$"
                )
            )
        else:
            fig, axes = None, None
            lines = None
        # Lower clamp for plotted values (log scale cannot show zeros).
        # np.result_type replaces np.find_common_type, removed in NumPy 2.0.
        dtypes = [p.dtype for p in energy_parameters.values()]
        common_dtype = np.result_type(*dtypes) if dtypes else np.float64
        ulp = np.finfo(common_dtype).eps
        ulp = 1e-4 * (ulp**2)
        self.fig = fig
        self.axes = axes
        self.lines = lines
        self.filename = filename
        self.io_params = io_params
        # see https://stackoverflow.com/questions/8257385/automatic-detection-of-display-availability-with-matplotlib
        self.has_gui_running = hasattr(fig, "show")
        self.should_draw = should_draw
        self.basex = basex
        self.basey = basey
        self.energy_parameters = energy_parameters
        self.ulp = ulp
        self.plt = plt

    def update(self, simulation, wait_for):
        """
        Refresh, display and save the figure when this rank draws and io_params
        requests a dump; waits on wait_for first when it is not None.
        """
        if not self.should_draw:
            return
        if not self.io_params.should_dump(simulation=simulation):
            return
        if wait_for is not None:
            wait_for.wait()
        ite = simulation.current_iteration
        self._update_plot(ite, simulation.t())
        self._draw()
        self._savefig(ite)

    def _update_plot(self, iteration, time):
        """Update curve data, y limits/ticks (powers of basey) and the title."""
        ymin = np.inf
        ymax = -np.inf
        for p, line in zip(self.energy_parameters.values(), self.lines):
            energy = p.value[1:]
            energy = np.maximum(energy, self.ulp)
            ymin = min(ymin, np.min(energy))
            ymax = max(ymax, np.max(energy))
            line.set_ydata(energy)
        basey = self.basey
        pmin = int(math.floor(math.log(ymin, basey)))
        pmax = int(math.ceil(math.log(ymax, basey)))
        if pmax == pmin:
            pmax += 1
        # Keep at most max_majors major ticks by spacing them nsubint decades apart.
        max_majors = 8
        nlevels = pmax - pmin + 1
        nsubint = 1
        while math.ceil(nlevels / float(nsubint)) > max_majors:
            nsubint += 1
        pmax = int(math.ceil(pmax / float(nsubint))) * nsubint
        pmin = int(math.floor(pmin / float(nsubint))) * nsubint
        assert (pmax - pmin) % nsubint == 0
        ymin = basey**pmin
        ymax = basey**pmax
        major_yticks = tuple(basey**i for i in range(pmin, pmax + 1, nsubint))
        minor_yticks = tuple(
            basey**i for i in range(pmin, pmax + 1, 1) if i % nsubint != 0
        )
        ax = self.axes[0]
        ax.set_ylim(ymin, ymax)
        ax.yaxis.set_ticks(major_yticks)
        ax.yaxis.set_ticks(minor_yticks, minor=True)
        ax.set_title(f"{self.fig_title}\niteration={iteration}, t={time}")
        ax.relim()

    def _draw(self):
        """Redraw and show the figure if an interactive window is still open."""
        if not self.has_gui_running:
            return
        self.fig.canvas.draw()
        self.fig.show()
        self.plt.pause(0.001)

    def _savefig(self, iteration):
        """Save the figure, substituting the 5-digit iteration for ``{ite}``."""
        filename = self.filename.format(ite=f"{iteration:05}")
        self.fig.savefig(filename, dpi=self.fig.dpi, bbox_inches="tight")

    def _on_close(self, event):
        self.has_gui_running = False

    def _on_key_press(self, event):
        # 'q' closes the window and stops further interactive redraws.
        key = event.key
        if key == "q":
            self.plt.close(self.fig)
            self.has_gui_running = False
if __name__ == "__main__":
    from hysop.tools.sympy_utils import round_expr

    def _show(expr):
        """Print expr with coefficients rounded to 2 digits, then a blank line."""
        print(round_expr(expr, 2))
        print()

    xl, xr = -1.0, +1.0
    dirichlet = BoundaryCondition.HOMOGENEOUS_DIRICHLET
    neumann = BoundaryCondition.HOMOGENEOUS_NEUMANN
    periodic = BoundaryCondition.PERIODIC

    # 1D examples.
    _show(make_trigonometric_polynomial(xl, xr, dirichlet, neumann, 10)[0])
    _show(make_polynomial(xl, xr, neumann, dirichlet, 10, 2)[0])

    # 2D examples: axis 0 is (NEUMANN, DIRICHLET), axis 1 is periodic.
    left_bcs = np.asarray([neumann, periodic])
    right_bcs = np.asarray([dirichlet, periodic])
    _show(
        make_multivariate_trigonometric_polynomial(xl, xr, left_bcs, right_bcs, (3, 5))[0]
    )
    _show(make_multivariate_polynomial(xl, xr, left_bcs, right_bcs, (6, 10), 2)[0])